import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from pbb.models import ProbNNet4l, trainPNNet
from pbb.bounds import PBBobj
from pbb import data


def _relative_change(posterior, prior):
    """Return ``||posterior - prior|| / ||prior||`` as a NumPy value.

    No guard is applied for a zero-norm ``prior``; the division then follows
    the usual torch floating-point semantics (inf or nan).
    """
    change = torch.norm(posterior - prior) / torch.norm(prior)
    return change.cpu().detach().numpy()


def _flat_params(weight, bias, attr):
    """Concatenate the flattened ``attr`` ('mu' or 'rho') of weight and bias."""
    return torch.cat([getattr(weight, attr).flatten(),
                      getattr(bias, attr).flatten()])


def _record_layer_ratios(layer, separate, combined):
    """Append the current posterior-vs-prior ratios of one layer.

    ``separate`` is ``[bias_mu, bias_rho, weight_mu, weight_rho]`` and
    ``combined`` is ``[rho, mu]``; each entry is a history list that receives
    one new value per call.
    """
    separate[0].append(_relative_change(layer.bias.mu, layer.bias_prior.mu))
    separate[1].append(_relative_change(layer.bias.rho, layer.bias_prior.rho))
    separate[2].append(
        _relative_change(layer.weight.mu, layer.weight_prior.mu))
    separate[3].append(
        _relative_change(layer.weight.rho, layer.weight_prior.rho))

    # Combined ratios treat weight and bias of the layer as a single vector.
    for history, attr in zip(combined, ("rho", "mu")):
        history.append(_relative_change(
            _flat_params(layer.weight, layer.bias, attr),
            _flat_params(layer.weight_prior, layer.bias_prior, attr)))


def runexp(Net_Width, name_data, objective, prior_type, model, sigma_prior,
           pmin, learning_rate=1, momentum=0, learning_rate_prior=0,
           momentum_prior=0, delta=0.025, layers=9, delta_test=0.01,
           mc_samples=1000, samples_ensemble=100, kl_penalty=1,
           initial_lamb=6.0, train_epochs=100, prior_dist='gaussian',
           verbose=False, device='cuda', prior_epochs=0, dropout_prob=0.2,
           perc_train=1.0, verbose_test=True, perc_prior=0.2, batch_size=128,
           full_weight_bias_list=None, record_every=100):
    """Train a ``ProbNNet4l`` and track how far it drifts from its prior.

    The network is trained with SGD through ``trainPNNet``. Every
    ``record_every`` epochs (when ``verbose_test`` is true) the relative
    change ``||posterior - prior|| / ||prior||`` is recorded for the ``mu``
    and ``rho`` parameters of layers ``l1`` .. ``l4``.

    Args:
        Net_Width: forwarded to ``ProbNNet4l`` as ``Net_Width``.
        name_data: dataset name passed to ``data.loaddataset``.
        objective, pmin, delta, delta_test, mc_samples, kl_penalty:
            forwarded to ``PBBobj``.
        sigma_prior: prior scale; converted to ``rho`` via
            ``log(exp(sigma_prior) - 1)`` before building the network.
        learning_rate, momentum: SGD hyper-parameters.
        train_epochs: number of training epochs.
        prior_dist, device: forwarded to ``ProbNNet4l`` (``device`` also to
            ``PBBobj``).
        verbose: forwarded to ``trainPNNet``.
        verbose_test: when false, no ratios are recorded at all.
        perc_train, perc_prior, batch_size: forwarded to ``data.loadbatches``.
            ``batch_size`` is also used as ``n_posterior`` / ``n_bound`` of
            the bound object.
        full_weight_bias_list: forwarded to ``ProbNNet4l`` as ``init_net``.
        record_every: record the ratios whenever ``(epoch + 1)`` is a
            multiple of this value (default 100, the original behaviour).
        prior_type, model, learning_rate_prior, momentum_prior, layers,
            samples_ensemble, initial_lamb, prior_epochs, dropout_prob:
            accepted for interface compatibility; not used by this function.

    Returns:
        A 5-tuple ``(combined_layers, ratio, diff, layers, meta_date)``:
        ``combined_layers[i]`` is ``[rho_history, mu_history]`` for layer
        ``i + 1`` (weight and bias concatenated); ``ratio[i]`` is
        ``[bias_mu, bias_rho, weight_mu, weight_rho]`` histories for the same
        layer; the last three items are always empty lists.

    Raises:
        ValueError: if ``record_every`` is smaller than 1.
    """
    if record_every < 1:
        raise ValueError("record_every must be a positive integer")

    # this makes the initialised prior the same for all bounds
    torch.manual_seed(11)
    np.random.seed(4)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    loader_kargs = {'num_workers': 1,
                    'pin_memory': True} if torch.cuda.is_available() else {}

    train, test = data.loaddataset(name_data)
    rho_prior = math.log(math.exp(sigma_prior) - 1.0)
    train_loader, _, _, _, _, _ = data.loadbatches(
        train, test, loader_kargs, batch_size, prior=False,
        perc_train=perc_train, perc_prior=perc_prior)

    # NOTE(review): the bound uses the batch size, not the dataset size, as
    # its sample counts -- kept as in the original; confirm this is intended.
    posterior_n_size = batch_size
    bound_n_size = batch_size
    classes = len(train_loader.dataset.classes)
    net = ProbNNet4l(rho_prior, prior_dist=prior_dist, device=device,
                     init_net=full_weight_bias_list,
                     Net_Width=Net_Width).to(device)
    bound = PBBobj(objective, pmin, classes, delta, delta_test, mc_samples,
                   kl_penalty, device, n_posterior=posterior_n_size,
                   n_bound=bound_n_size)
    # No lambda variable / optimiser is used; trainPNNet receives None for both.
    optimizer_lambda = None
    lambda_var = None
    optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                          momentum=momentum)

    n_layers = 4
    # ratio[i] -> [bias_mu, bias_rho, weight_mu, weight_rho] histories
    ratio = [[[] for _ in range(4)] for _ in range(n_layers)]
    # combined_layers[i] -> [rho, mu] histories
    combined_layers = [[[] for _ in range(2)] for _ in range(n_layers)]

    for epoch in trange(train_epochs):
        trainPNNet(net, optimizer, bound, epoch, train_loader, lambda_var,
                   optimizer_lambda, verbose)
        if verbose_test and ((epoch + 1) % record_every == 0):
            tracked_layers = (net.l1, net.l2, net.l3, net.l4)
            for layer, separate, combined in zip(tracked_layers, ratio,
                                                 combined_layers):
                _record_layer_ratios(layer, separate, combined)

    # The remaining return slots are placeholders and are always empty.
    diff = []
    empty_layers = []
    meta_date = []

    return combined_layers, ratio, diff, empty_layers, meta_date

def count_parameters(model):
    """Return the total element count of the trainable parameters of model.

    Only parameters whose ``requires_grad`` flag is truthy are counted; a
    model without such parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
